{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {},
     "inputWidgets": {},
     "nuid": "5147e458-3b83-449e-9c2f-e7e1972e43fc",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "# Databricks\n",
    "\n",
    "The [Databricks](https://www.databricks.com/) Lakehouse Platform unifies data, analytics, and AI on one platform.\n",
    "\n",
    "This example notebook shows how to wrap Databricks endpoints as LLMs in LangChain.\n",
    "The `Databricks` wrapper supports two endpoint types:\n",
    "* Serving endpoint, recommended for production and development,\n",
    "* Cluster driver proxy app, recommended for interactive development."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "bf07455f-aac9-4873-a8e7-7952af0f8c82",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "from langchain.llms import Databricks"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {},
     "inputWidgets": {},
     "nuid": "94f6540e-40cd-4d9b-95d3-33d36f061dcc",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "## Wrapping a serving endpoint\n",
    "\n",
    "Prerequisites:\n",
    "* An LLM was registered and deployed to [a Databricks serving endpoint](https://docs.databricks.com/machine-learning/model-serving/index.html).\n",
    "* You have [\"Can Query\" permission](https://docs.databricks.com/security/auth-authz/access-control/serving-endpoint-acl.html) to the endpoint.\n",
    "\n",
    "The expected MLflow model signature is:\n",
    "  * inputs: `[{\"name\": \"prompt\", \"type\": \"string\"}, {\"name\": \"stop\", \"type\": \"list[string]\"}]`\n",
    "  * outputs: `[{\"type\": \"string\"}]`\n",
    "\n",
    "If the model signature is incompatible or you want to insert extra configs, you can set `transform_input_fn` and `transform_output_fn` accordingly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "7496dc7a-8a1a-4ce6-9648-4f69ed25275b",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'I am happy to hear that you are in good health and as always, you are appreciated.'"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# If running a Databricks notebook attached to an interactive cluster in \"single user\"\n",
    "# or \"no isolation shared\" mode, you only need to specify the endpoint name to create\n",
    "# a `Databricks` instance to query a serving endpoint in the same workspace.\n",
    "llm = Databricks(endpoint_name=\"dolly\")\n",
    "\n",
    "llm(\"How are you?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "0c86d952-4236-4a5e-bdac-cf4e3ccf3a16",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Good'"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "llm(\"How are you?\", stop=[\".\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "5f2507a2-addd-431d-9da5-dc2ae33783f6",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'I am fine. Thank you!'"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Otherwise, you can manually specify the Databricks workspace hostname and personal access token\n",
    "# or set `DATABRICKS_HOST` and `DATABRICKS_TOKEN` environment variables, respectively.\n",
    "# See https://docs.databricks.com/dev-tools/auth.html#databricks-personal-access-tokens\n",
    "# We strongly recommend not exposing the API token explicitly inside a notebook.\n",
    "# You can use Databricks secret manager to store your API token securely.\n",
    "# See https://docs.databricks.com/dev-tools/databricks-utils.html#secrets-utility-dbutilssecrets\n",
    "\n",
    "import os\n",
    "\n",
    "os.environ[\"DATABRICKS_TOKEN\"] = dbutils.secrets.get(\"myworkspace\", \"api_token\")\n",
    "\n",
    "llm = Databricks(host=\"myworkspace.cloud.databricks.com\", endpoint_name=\"dolly\")\n",
    "\n",
    "llm(\"How are you?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "9b54f8ce-ffe5-4c47-a3f0-b4ebde524a6a",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'I am fine.'"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# If the serving endpoint accepts extra parameters like `temperature`,\n",
    "# you can set them in `model_kwargs`.\n",
    "llm = Databricks(endpoint_name=\"dolly\", model_kwargs={\"temperature\": 0.1})\n",
    "\n",
    "llm(\"How are you?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "50f172f5-ea1f-4ceb-8cf1-20289848de7b",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'I’m Excellent. You?'"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Use `transform_input_fn` if the serving endpoint expects a different input schema\n",
    "# or you want to apply a prompt template on top, and use `transform_output_fn`\n",
    "# if the endpoint does not return a JSON string.\n",
    "\n",
    "\n",
    "def transform_input(**request):\n",
    "    full_prompt = f\"\"\"{request[\"prompt\"]}\n",
    "    Be Concise.\n",
    "    \"\"\"\n",
    "    request[\"prompt\"] = full_prompt\n",
    "    return request\n",
    "\n",
    "\n",
    "llm = Databricks(endpoint_name=\"dolly\", transform_input_fn=transform_input)\n",
    "\n",
    "llm(\"How are you?\")"
   ]
  },
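  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The reshaping done by `transform_input_fn` and `transform_output_fn` can be tried offline, without an endpoint. The sketch below assumes the wrapper builds a request from `prompt`, `stop`, and any `model_kwargs` before calling `transform_input_fn`. The `build_request` helper and the `{\"inputs\": ..., \"params\": ...}` and `[{\"generated_text\": ...}]` schemas are hypothetical and only illustrate the idea; they are not the schema of any particular endpoint.\n",
    "\n",
    "```python\n",
    "# Hypothetical stand-in for the request the wrapper is assumed to build\n",
    "# before calling `transform_input_fn`: prompt, stop, plus any `model_kwargs`.\n",
    "def build_request(prompt, stop=None, **model_kwargs):\n",
    "    request = {\"prompt\": prompt, \"stop\": stop}\n",
    "    request.update(model_kwargs)\n",
    "    return request\n",
    "\n",
    "\n",
    "def transform_input(**request):\n",
    "    # Reshape into a hypothetical schema: {\"inputs\": [...], \"params\": {...}}.\n",
    "    prompt = request.pop(\"prompt\")\n",
    "    stop = request.pop(\"stop\", None)\n",
    "    params = dict(request)  # whatever is left came from `model_kwargs`\n",
    "    if stop:\n",
    "        params[\"stop\"] = stop\n",
    "    return {\"inputs\": [prompt], \"params\": params}\n",
    "\n",
    "\n",
    "def transform_output(response):\n",
    "    # Assume the hypothetical endpoint returns [{\"generated_text\": \"...\"}].\n",
    "    return response[0][\"generated_text\"].strip()\n",
    "\n",
    "\n",
    "payload = transform_input(**build_request(\"How are you?\", stop=[\".\"], temperature=0.1))\n",
    "print(payload)\n",
    "print(transform_output([{\"generated_text\": \" I am fine \"}]))\n",
    "```\n",
    "\n",
    "Passing these two functions as `transform_input_fn` and `transform_output_fn` would apply the same reshaping to live requests, provided the endpoint really uses such a schema."
   ]
  },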
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {},
     "inputWidgets": {},
     "nuid": "8ea49319-a041-494d-afcd-87bcf00d5efb",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "## Wrapping a cluster driver proxy app\n",
    "\n",
    "Prerequisites:\n",
    "* An LLM loaded on a Databricks interactive cluster in \"single user\" or \"no isolation shared\" mode.\n",
    "* A local HTTP server running on the driver node to serve the model at `\"/\"` using HTTP POST with JSON input/output.\n",
    "* The server uses a port number in the range `[3000, 8000]` and listens on the driver IP address or simply `0.0.0.0`, not only on localhost.\n",
    "* You have \"Can Attach To\" permission to the cluster.\n",
    "\n",
    "The expected server schema (using JSON schema) is:\n",
    "* inputs:\n",
    "  ```json\n",
    "  {\"type\": \"object\",\n",
    "   \"properties\": {\n",
    "      \"prompt\": {\"type\": \"string\"},\n",
    "       \"stop\": {\"type\": \"array\", \"items\": {\"type\": \"string\"}}},\n",
    "    \"required\": [\"prompt\"]}\n",
    "  ```\n",
    "* outputs: `{\"type\": \"string\"}`\n",
    "\n",
    "If the server schema is incompatible or you want to insert extra configs, you can use `transform_input_fn` and `transform_output_fn` accordingly.\n",
    "\n",
    "The following is a minimal example for running a driver proxy app to serve an LLM:\n",
    "\n",
    "```python\n",
    "from flask import Flask, request, jsonify\n",
    "import torch\n",
    "from transformers import pipeline, AutoTokenizer, StoppingCriteria\n",
    "\n",
    "model = \"databricks/dolly-v2-3b\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model, padding_side=\"left\")\n",
    "dolly = pipeline(model=model, tokenizer=tokenizer, trust_remote_code=True, device_map=\"auto\")\n",
    "device = dolly.device\n",
    "\n",
    "class CheckStop(StoppingCriteria):\n",
    "    def __init__(self, stop=None):\n",
    "        super().__init__()\n",
    "        self.stop = stop or []\n",
    "        self.matched = \"\"\n",
    "        self.stop_ids = [tokenizer.encode(s, return_tensors='pt').to(device) for s in self.stop]\n",
    "    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):\n",
    "        for i, s in enumerate(self.stop_ids):\n",
    "            if torch.all((s == input_ids[0][-s.shape[1]:])).item():\n",
    "                self.matched = self.stop[i]\n",
    "                return True\n",
    "        return False\n",
    "\n",
    "def llm(prompt, stop=None, **kwargs):\n",
    "  check_stop = CheckStop(stop)\n",
    "  result = dolly(prompt, stopping_criteria=[check_stop], **kwargs)\n",
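    "  # Drop the matched stop sequence from the end; `rstrip` would strip a character set, not a suffix.\n",
    "  return result[0][\"generated_text\"].removesuffix(check_stop.matched)\n",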
    "\n",
    "app = Flask(\"dolly\")\n",
    "\n",
    "@app.route('/', methods=['POST'])\n",
    "def serve_llm():\n",
    "  resp = llm(**request.json)\n",
    "  return jsonify(resp)\n",
    "\n",
    "app.run(host=\"0.0.0.0\", port=\"7777\")\n",
    "```\n",
    "\n",
    "Once the server is running, you can create a `Databricks` instance to wrap it as an LLM."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "e3330a01-e738-4170-a176-9954aff56442",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Hello, thank you for asking. It is wonderful to hear that you are well.'"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# If running a Databricks notebook attached to the same cluster that runs the app,\n",
    "# you only need to specify the driver port to create a `Databricks` instance.\n",
    "llm = Databricks(cluster_driver_port=\"7777\")\n",
    "\n",
    "llm(\"How are you?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "39c121cf-0e44-4e31-91db-37fcac459677",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'I am well. You?'"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Otherwise, you can manually specify the cluster ID to use,\n",
    "# as well as Databricks workspace hostname and personal access token.\n",
    "\n",
    "llm = Databricks(cluster_id=\"0000-000000-xxxxxxxx\", cluster_driver_port=\"7777\")\n",
    "\n",
    "llm(\"How are you?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "3d3de599-82fd-45e4-8d8b-bacfc49dc9ce",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'I am very well. It is a pleasure to meet you.'"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# If the app accepts extra parameters like `temperature`,\n",
    "# you can set them in `model_kwargs`.\n",
    "llm = Databricks(cluster_driver_port=\"7777\", model_kwargs={\"temperature\": 0.1})\n",
    "\n",
    "llm(\"How are you?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "853fae8e-8df4-41e6-9d45-7769f883fe80",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'I AM DOING GREAT THANK YOU.'"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Use `transform_input_fn` if the app expects a different input schema\n",
    "# or you want to apply a prompt template on top, and use `transform_output_fn`\n",
    "# if the app does not return a JSON string or you want to post-process the response.\n",
    "\n",
    "\n",
    "def transform_input(**request):\n",
    "    full_prompt = f\"\"\"{request[\"prompt\"]}\n",
    "    Be Concise.\n",
    "    \"\"\"\n",
    "    request[\"prompt\"] = full_prompt\n",
    "    return request\n",
    "\n",
    "\n",
    "def transform_output(response):\n",
    "    return response.upper()\n",
    "\n",
    "\n",
    "llm = Databricks(\n",
    "    cluster_driver_port=\"7777\",\n",
    "    transform_input_fn=transform_input,\n",
    "    transform_output_fn=transform_output,\n",
    ")\n",
    "\n",
    "llm(\"How are you?\")"
   ]
  }
 ],
 "metadata": {
  "application/vnd.databricks.v1+notebook": {
   "dashboards": [],
   "language": "python",
   "notebookMetadata": {
    "pythonIndentUnit": 2
   },
   "notebookName": "databricks",
   "widgets": {}
  },
  "kernelspec": {
   "display_name": "llm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
